In the blog post "Asymptotic Estimation of Weight RMS in AdamW (Part 1)", we derived an asymptotic expression for the RMS of model weights trained with AdamW. However, that derivation assumed the Weight Decay and learning rate stay fixed throughout training, which does not match how models are usually trained in practice. In this article, we therefore extend the previous conclusions to a dynamic version.
In the dynamic version, both the Weight Decay and the learning rate may vary with the training step, as in classic schedules such as Cosine Decay and WSD (Warmup Stable Decay), which makes the conclusions more general.
Step One
Our starting point remains the definition of AdamW:
(1)
\[
\text{Adam}\color{skyblue}{\text{W}}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\
&\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t^2\\
&\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\
&\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\
&\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}})
\end{aligned}\right.
\]
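To make Eq. (1) concrete, here is a minimal NumPy sketch of a single AdamW step with a per-step learning rate and Weight Decay. The helper name `adamw_step`, its default hyperparameters, and the toy schedule in the usage loop are illustrative assumptions, not part of the original derivation.

```python
import numpy as np


def adamw_step(theta, m, v, g, t, lr_t, wd_t, beta1=0.9, beta2=0.95, eps=1e-8):
    """One AdamW update as in Eq. (1); t is the 1-based step index (hypothetical helper)."""
    m = beta1 * m + (1 - beta1) * g             # first-moment EMA
    v = beta2 * v + (1 - beta2) * g**2          # second-moment EMA
    m_hat = m / (1 - beta1**t)                  # bias corrections
    v_hat = v / (1 - beta2**t)
    u = m_hat / (np.sqrt(v_hat) + eps)          # normalized update direction u_t
    theta = theta - lr_t * (u + wd_t * theta)   # decoupled decay acts on theta_{t-1}
    return theta, m, v


# Usage with illustrative per-step schedules lr_t and wd_t.
rng = np.random.default_rng(0)
theta, m, v = rng.standard_normal(4), np.zeros(4), np.zeros(4)
for t in range(1, 101):
    g = rng.standard_normal(4)
    lr_t = 1e-3 * (1 - (t - 1) / 100)   # illustrative linear decay
    wd_t = 0.1                           # illustrative constant Weight Decay
    theta, m, v = adamw_step(theta, m, v, g, t, lr_t, wd_t)
```

Because `lr_t` and `wd_t` are passed in at every call, any schedule for \(\eta_t\) and \(\lambda_t\) fits this interface.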
Since \(\eta_t\lambda_t\ll 1\), we can write:
(2)
\(\boldsymbol{\theta}_t = (1 - \eta_t\lambda_t)\boldsymbol{\theta}_{t-1} -\eta_t\boldsymbol{u}_t \approx e^{- \eta_t\lambda_t}\boldsymbol{\theta}_{t-1} -\eta_t\boldsymbol{u}_t\)
Let \(\kappa_t = \sum_{i=1}^t \eta_i\lambda_i\), then direct expansion gives:
(3)
\(\boldsymbol{\theta}_t \approx e^{-\kappa_t}\boldsymbol{\theta}_0 - \sum_{i=1}^t e^{-(\kappa_t - \kappa_i)}\eta_i\boldsymbol{u}_i = e^{-\kappa_t}\left(\boldsymbol{\theta}_0 - \sum_{i=1}^t e^{\kappa_i}\eta_i\boldsymbol{u}_i\right)\)
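The unrolled form (3) can be checked numerically. The sketch below assumes random stand-in update directions `u` and an illustrative cosine-shaped schedule; it iterates recurrence (2) with the exponential decay factor and compares the result with the closed form, and it also runs the original \((1 - \eta_t\lambda_t)\) factor to show how small the difference is when \(\eta_t\lambda_t \ll 1\).

```python
import numpy as np

rng = np.random.default_rng(0)
T, N = 500, 8
steps = np.arange(1, T + 1)
eta = 1e-3 * (0.5 + 0.5 * np.cos(steps / T * np.pi))  # illustrative schedule
lam = np.full(T, 0.1)                                  # illustrative Weight Decay
u = rng.standard_normal((T, N))                        # stand-ins for the updates u_i
theta0 = rng.standard_normal(N)

# Recurrence (2): exponential decay factor, and the original (1 - eta*lambda) factor.
theta_exp, theta_lin = theta0.copy(), theta0.copy()
for i in range(T):
    theta_exp = np.exp(-eta[i] * lam[i]) * theta_exp - eta[i] * u[i]
    theta_lin = (1 - eta[i] * lam[i]) * theta_lin - eta[i] * u[i]

# Closed form (3) with kappa_i = sum_{k<=i} eta_k * lambda_k.
kappa = np.cumsum(eta * lam)
closed = np.exp(-kappa[-1]) * (theta0 - (np.exp(kappa)[:, None] * eta[:, None] * u).sum(axis=0))
```

The exponential version agrees with (3) to floating-point precision, since (3) is an exact unrolling of it; the \((1-\eta_t\lambda_t)\) version differs only by terms of order \((\eta_t\lambda_t)^2\) per step.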
Next, let \(z_t = \sum_{i=1}^t e^{\kappa_i}\eta_i\) and apply the mean-field approximation:
(4)
\[
\bar{\boldsymbol{u}}_t\triangleq\frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i \boldsymbol{u}_i = \frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i \frac{\boldsymbol{m}_i}{\sqrt{\boldsymbol{v}_i}}\approx \frac{\bar{\boldsymbol{m}}_t \triangleq \frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\boldsymbol{m}_i}{\sqrt{\bar{\boldsymbol{v}}_t \triangleq \frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\boldsymbol{v}_i}}
\]
Thus:
(5)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0 - z_t \bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{-2\kappa_t}z_t^2\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2\)
Step Two
As in the previous article, to estimate \(\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2\) we assume the \(\boldsymbol{g}_j\) are i.i.d. samples from \(\mathcal{N}(\boldsymbol{\mu},\boldsymbol{\sigma}^2)\) and compute:
(6)
\(\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \mathbb{E}\left[\frac{\bar{\boldsymbol{m}}_t^2}{\bar{\boldsymbol{v}}_t}\right] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]}\)
Finally, averaging over components of \(\mathbb{E}[\bar{\boldsymbol{u}}_t^2]\) yields an approximation for \(\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2\).
Expanding \(\boldsymbol{m}_t,\boldsymbol{v}_t\) gives:
(7)
\(\boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2\)
We also have the identity:
(8)
\(\sum_{i=1}^t \sum_{j=1}^i a_i b_j = \sum_{j=1}^t \sum_{i=j}^t a_i b_j\)
Using these results, we can write:
(9)
\[
\begin{aligned}
\bar{\boldsymbol{m}}_t &= \frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\boldsymbol{m}_i = \frac{1 - \beta_1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = \sum_{j=1}^t\boldsymbol{g}_j\underbrace{\frac{1 - \beta_1}{z_t}\sum_{i=j}^t e^{\kappa_i}\beta_1^{i-j}\eta_i}_{\text{denoted as }\bar{\beta}_1(j,t)} \\
\bar{\boldsymbol{v}}_t &= \frac{1}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\boldsymbol{v}_i = \frac{1 - \beta_2}{z_t}\sum_{i=1}^t e^{\kappa_i}\eta_i\sum_{j=1}^i \beta_2^{i-j}\boldsymbol{g}_j^2 = \sum_{j=1}^t\boldsymbol{g}_j^2\underbrace{\frac{1 - \beta_2}{z_t}\sum_{i=j}^t e^{\kappa_i}\beta_2^{i-j}\eta_i}_{\text{denoted as }\bar{\beta}_2(j,t)}
\end{aligned}
\]
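As a sanity check of the coefficients in (9), the following sketch computes \(\bar{\boldsymbol{m}}_t\) for a single scalar gradient component in two ways: as the weighted average of the EMAs \(\boldsymbol{m}_i\), and as \(\sum_j \bar{\beta}_1(j,t)\boldsymbol{g}_j\). The linear learning-rate schedule and all constants are illustrative assumptions, and evaluating the inner sums with the backward recursion \(S_j = e^{\kappa_j}\eta_j + \beta_1 S_{j+1}\) is an implementation choice rather than something stated in the derivation.

```python
import numpy as np

rng = np.random.default_rng(1)
T, beta1, lam = 2000, 0.9, 0.1
eta = np.linspace(1e-3, 1e-4, T)       # illustrative linear schedule
kappa = np.cumsum(lam * eta)
w = np.exp(kappa) * eta                # weights e^{kappa_i} * eta_i
z = w.sum()                            # z_t
g = rng.standard_normal(T)             # one scalar component of g_1..g_t

# Left-hand side of (9): weighted average of the EMAs m_i.
m, m_bar = 0.0, 0.0
for i in range(T):
    m = beta1 * m + (1 - beta1) * g[i]
    m_bar += w[i] * m / z

# Right-hand side of (9): bar_beta_1(j, t) = (1 - beta1) / z * S_j, where
# S_j = sum_{i=j}^{t} w_i * beta1^(i-j) satisfies S_j = w_j + beta1 * S_{j+1}.
S = np.empty(T)
acc = 0.0
for j in range(T - 1, -1, -1):
    acc = w[j] + beta1 * acc
    S[j] = acc
bar_beta1 = (1 - beta1) * S / z
```

The same arrays also show that \(\sum_j \bar{\beta}_1(j,t)\) is just below 1 once \(t\) is large, which is used in the next step.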
Step Three
First compute the denominator. When \(t\) is large enough that \(\beta_1^t\) and \(\beta_2^t\) are negligible, \(\sum_{j=1}^t \bar{\beta}_1(j,t)\) and \(\sum_{j=1}^t \bar{\beta}_2(j,t)\) are both close to 1, so:
(10)
\(\mathbb{E}[\bar{\boldsymbol{v}}_t] = \sum_{j=1}^t\bar{\beta}_2(j,t) \mathbb{E}[\boldsymbol{g}_j^2] = \sum_{j=1}^t\bar{\beta}_2(j,t) (\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \approx \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2\)
Similarly, \(\mathbb{E}[\bar{\boldsymbol{m}}_t] \approx \boldsymbol{\mu}\), and \(\mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]\). Since the \(\boldsymbol{g}_j\) are independent, their variances add:
(11)
\(\mathbb{V}ar[\bar{\boldsymbol{m}}_t] = \sum_{j=1}^t\bar{\beta}_1(j,t)^2 \mathbb{V}ar[\boldsymbol{g}_j] = \sum_{j=1}^t\bar{\beta}_1(j,t)^2 \boldsymbol{\sigma}^2\)
Thus:
(12)
\(\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2\sum_{j=1}^t\bar{\beta}_1(j,t)^2}{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2}\)
Averaging over components then gives:
(13)
\(\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \sum_{j=1}^t\bar{\beta}_1(j,t)^2}{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1}\)
Finally:
(14)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{-2\kappa_t}z_t^2\frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \sum_{j=1}^t\bar{\beta}_1(j,t)^2}{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1}\)
Readers coming to this article directly may find some of these steps abrupt; revisiting "Asymptotic Estimation of Weight RMS in AdamW (Part 1)" will clarify the reasoning behind each approximation.
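Equation (14) can be evaluated directly for any schedule. The sketch below assumes a hypothetical helper `weight_rms_eq14` that takes per-step arrays `eta` and `lam`, the momentum `beta1`, the initial RMS, and `snr2` standing for \(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2\); with `snr2=0` it reduces to Equation (15) below. The backward recursion for the inner sums keeps the cost at \(O(t)\) instead of the naive \(O(t^2)\).

```python
import numpy as np


def weight_rms_eq14(eta, lam, beta1, init_rms, snr2=0.0):
    """Estimate ||theta_t||_RMS via Eq. (14) (hypothetical helper).

    eta, lam: per-step learning rate and Weight Decay for steps 1..t.
    snr2: ||mu||^2 / ||sigma||^2 (0 for zero-mean gradients).
    """
    kappa = np.cumsum(eta * lam)
    w = np.exp(kappa) * eta                      # e^{kappa_i} * eta_i
    z = w.sum()                                  # z_t
    S = np.empty_like(w)                         # S_j = sum_{i>=j} w_i beta1^(i-j)
    acc = 0.0
    for j in range(len(w) - 1, -1, -1):
        acc = w[j] + beta1 * acc
        S[j] = acc
    sum_bar_beta1_sq = (((1 - beta1) * S / z) ** 2).sum()
    u_rms2 = (snr2 + sum_bar_beta1_sq) / (snr2 + 1)   # Eq. (13)
    rho = np.exp(-2 * kappa[-1]) * (init_rms**2 + z**2 * u_rms2)
    return float(np.sqrt(rho))


# Usage: cosine schedule and hyperparameters from the simulation section.
T = 10000
lr = 1e-4 + (1e-3 - 1e-4) * (1 + np.cos(np.arange(T) / T * np.pi)) / 2
estimate = weight_rms_eq14(lr, np.full(T, 0.1), beta1=0.9, init_rms=0.1)
```

With `lam` set to zero and `beta1=0`, the helper returns \(\sqrt{\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \sum_j\eta_j^2}\), which matches the special case derived below.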
Example One
First consider \(\boldsymbol{\mu}=\boldsymbol{0}\). Substituting the expression for \(\bar{\beta}_1(j,t)\) into Equation (14) yields:
(15)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{-2\kappa_t}(1-\beta_1)^2\sum_{j=1}^t\left(\sum_{i=j}^t e^{\kappa_i}\beta_1^{i-j}\eta_i\right)^2\)
Now consider the simplest case \(\lambda_t=0\), i.e., no Weight Decay. Then:
(16)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_1)^2\sum_{j=1}^t\left(\sum_{i=j}^t \beta_1^{i-j}\eta_i\right)^2\)
If \(\beta_1\to 0\), we immediately get \(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \sum_{j=1}^t\eta_j^2\). This shows that without Weight Decay, keeping the Weight RMS from exploding as \(t\to\infty\) requires the sum of squared learning rates to converge.
In fact, even for \(0 < \beta_1 < 1\), this condition remains necessary and sufficient:
(17)
\(\sum_{j=1}^{\infty}\left(\sum_{i=j}^{\infty} \beta_1^{i-j}\eta_i\right)^2 < \infty \qquad\Leftrightarrow\qquad \sum_{j=1}^{\infty}\eta_j^2 < \infty\)
The proof is not difficult. Transform the left side:
(18)
\[
\begin{aligned}
\sum_{j=1}^{\infty}\left(\sum_{i=j}^{\infty} \beta_1^{i-j}\eta_i\right)^2 = \sum_{j=1}^{\infty}\left(\sum_{i=0}^{\infty} \beta_1^i\eta_{i+j}\right)^2 =&\, \sum_{j=1}^{\infty}\left(\sum_{i_1=0}^{\infty} \beta_1^{i_1}\eta_{i_1+j}\right)\left(\sum_{i_2=0}^{\infty} \beta_1^{i_2}\eta_{i_2+j}\right) \\
=&\, \sum_{i_1=0}^{\infty}\sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2}\sum_{j=1}^{\infty}\eta_{i_1+j}\eta_{i_2+j}
\end{aligned}
\]
Since learning rates are nonnegative, every term here is nonnegative. Hence, if the left side converges, each inner sum \(\sum_{j=1}^{\infty}\eta_{i_1+j}\eta_{i_2+j}\) converges; in particular, the \(i_1=i_2=0\) term is exactly \(\sum_{j=1}^{\infty}\eta_j^2\), which proves necessity. For sufficiency, start from the same expression and apply the Cauchy-Schwarz inequality:
(19)
\[
\begin{aligned}
\sum_{i_1=0}^{\infty}\sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2}\sum_{j=1}^{\infty}\eta_{i_1+j}\eta_{i_2+j} \leq&\, \sum_{i_1=0}^{\infty}\sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2}\sqrt{\left(\sum_{j=1}^{\infty}\eta_{i_1+j}^2\right)\left(\sum_{j=1}^{\infty}\eta_{i_2+j}^2\right)} \\
\leq&\, \sum_{i_1=0}^{\infty}\sum_{i_2=0}^{\infty} \beta_1^{i_1 + i_2}\sqrt{\left(\sum_{j=1}^{\infty}\eta_j^2\right)\left(\sum_{j=1}^{\infty}\eta_j^2\right)} \\
=&\, \frac{1}{(1-\beta_1)^2} \sum_{j=1}^{\infty}\eta_j^2
\end{aligned}
\]
Thus convergence of \(\sum_{j=1}^{\infty}\eta_j^2\) implies convergence of the left side, proving sufficiency.
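A quick numerical illustration of (17), assuming two illustrative schedules: \(\eta_j = 1/j\), whose squares are summable, and \(\eta_j = 1/\sqrt{j}\), whose squares form the divergent harmonic series. Over a finite horizon, the left side of (17) is computed with a backward recursion (the helper `lhs_partial` is hypothetical) and checked against the bounds \(\sum_j\eta_j^2 \le \text{LHS} \le \sum_j\eta_j^2/(1-\beta_1)^2\) that appear in the proof.

```python
import numpy as np


def lhs_partial(eta, beta1):
    """Left side of (17) over a finite horizon: sum_j (sum_{i>=j} beta1^(i-j) eta_i)^2."""
    acc, total = 0.0, 0.0
    for e in eta[::-1]:            # backward recursion for the inner sum
        acc = e + beta1 * acc
        total += acc**2
    return total


beta1 = 0.9
j = np.arange(1, 100001, dtype=float)
for name, eta in [("1/j", 1 / j), ("1/sqrt(j)", 1 / np.sqrt(j))]:
    sq = (eta**2).sum()
    # Bounds from the proof: sum eta^2 <= LHS <= sum eta^2 / (1 - beta1)^2
    print(name, sq, lhs_partial(eta, beta1), sq / (1 - beta1) ** 2)
```

For \(\eta_j = 1/j\) the printed values stay bounded as the horizon grows (the sum of squares tends to \(\pi^2/6\)), while for \(\eta_j = 1/\sqrt{j}\) both the sum of squares and the left side grow without bound.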
Example Two
Next, we consider the case where the Weight Decay is constant but the learning rate varies, i.e., \(\kappa_t = \lambda\sum_{i=1}^t \eta_i\). For infinitely long training to approach the theoretical optimum as closely as possible, the learning rate should satisfy \(\sum_{i=1}^{\infty} \eta_i = \infty\).
In general, Equation (15) is hard to evaluate, but practical settings permit a further approximation. In actual training \(\lambda_t \eta_t \ll 1\), so \(e^{\kappa_i}\) grows much more slowly than \(\beta_1^{i-j}\) decays; the learning rate schedule likewise changes little over the roughly \(1/(1-\beta_1)\) steps in which \(\beta_1^{i-j}\) is non-negligible. Thus, we can approximate:
(20)
\(\sum_{i=j}^t e^{\kappa_i}\beta_1^{i-j}\eta_i \approx \sum_{i=j}^t e^{\kappa_j}\beta_1^{i-j}\eta_j = e^{\kappa_j}\eta_j\sum_{i=j}^t\beta_1^{i-j}\approx e^{\kappa_j}\eta_j\sum_{i=j}^{\infty}\beta_1^{i-j} = \frac{e^{\kappa_j}\eta_j}{1-\beta_1}\)
Substituting this approximation back into Equation (15):
(21)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{- 2\kappa_t}\sum_{j=1}^t e^{2\kappa_j}\eta_j^2\)
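To see how good approximation (20) is in a realistic setting, the sketch below assumes the cosine schedule and hyperparameters of the simulation at the end of this article. It compares the exact inner sums of Equation (15) with \(e^{\kappa_j}\eta_j/(1-\beta_1)\), and then the resulting second terms of (15) and (21). The last indices are excluded from the termwise check because there the inner sum is truncated at \(t\).

```python
import numpy as np

T, beta1, lam = 10000, 0.9, 0.1
lr_max, lr_min = 1e-3, 1e-4
eta = lr_min + (lr_max - lr_min) * (1 + np.cos(np.arange(T) / T * np.pi)) / 2
kappa = lam * np.cumsum(eta)
w = np.exp(kappa) * eta

# Exact inner sums of (15): S_j = sum_{i=j}^{t} e^{kappa_i} beta1^(i-j) eta_i.
S = np.empty(T)
acc = 0.0
for j in range(T - 1, -1, -1):
    acc = w[j] + beta1 * acc
    S[j] = acc

approx = w / (1 - beta1)                   # right-hand side of (20)
rel_err = np.abs(S / approx - 1)

# Second terms of (15) and (21).
term15 = np.exp(-2 * kappa[-1]) * (1 - beta1) ** 2 * (S**2).sum()
term21 = np.exp(-2 * kappa[-1]) * (w**2).sum()
```

The termwise error is governed mainly by how fast the schedule changes over about \(1/(1-\beta_1)\) steps, which is why a slowly varying cosine schedule is well served by (20).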
From here, the computation depends on the specific schedule \(\eta_j\). For example, when \(\lambda_j,\eta_j\) are constant, we have \(\kappa_t = \lambda\eta t\), and:
(22)
\[
\begin{aligned}
\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx&\, e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{- 2\kappa_t}\sum_{j=1}^t e^{2\kappa_j}\eta_j^2 \\
=&\, e^{-2\lambda\eta t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{-2\lambda\eta t}\sum_{j=1}^t e^{2\lambda\eta j}\eta^2 \\
=&\, e^{-2\lambda\eta t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \frac{e^{2\lambda\eta}(1 - e^{-2\lambda\eta t})}{e^{2\lambda\eta} - 1}\eta^2 \\
\approx&\, e^{-2\lambda\eta t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1 - e^{-2\lambda\eta t} )\frac{\eta}{2\lambda}
\end{aligned}
\]
This matches the result from the previous article.
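The chain of approximations in (22) can also be checked numerically. In this sketch the constants `eta`, `lam`, `init_rms`, and `t` are illustrative; it evaluates the exact sum, the geometric-series closed form, and the final expression that uses \(e^{2\lambda\eta}/(e^{2\lambda\eta}-1)\approx 1/(2\lambda\eta)\).

```python
import numpy as np

eta, lam, init_rms, t = 1e-3, 0.1, 0.1, 20000     # illustrative constants
j = np.arange(1, t + 1)
decay = np.exp(-2 * lam * eta * t)                # e^{-2 kappa_t}

exact = decay * (init_rms**2 + (np.exp(2 * lam * eta * j) * eta**2).sum())  # second line of (22)
q = np.exp(2 * lam * eta)
geometric = decay * init_rms**2 + q * (1 - decay) / (q - 1) * eta**2         # third line
final = decay * init_rms**2 + (1 - decay) * eta / (2 * lam)                 # last line
```

Since \(2\lambda\eta = 2\times 10^{-4}\) here, the last approximation changes the second term only by a relative amount of order \(\lambda\eta\).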
Differential Equation
Since integrals are usually easier to compute than sums, we can approximate the sum by an integral:
(23)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{- 2\kappa_t}\sum_{j=1}^t e^{2\kappa_j}\eta_j^2\approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + e^{- 2\kappa_t}\int_0^t e^{2\kappa_s}\eta_s^2 ds\)
where \(\kappa_t = \int_0^t \lambda_s\eta_s ds\). Let \(\rho_t = \Vert\boldsymbol{\theta}_t\Vert_{RMS}^2\); multiplying both sides by \(e^{2\kappa_t}\) and differentiating with respect to \(t\) (using \(d\kappa_t/dt = \lambda_t\eta_t\)) gives:
(24)
\(\frac{d}{dt}\rho_t \approx -2\lambda_t\eta_t\rho_t + \eta_t^2\)
This is the differential equation governing the squared RMS. If \(\rho_t\) converges to a constant as \(t\to\infty\), the left side tends to 0, giving:
(25)
\(\lim_{t\to\infty} \rho_t \approx \lim_{t\to\infty} \frac{\eta_t}{2\lambda_t}\)
This tells us that for decaying learning rate schedules, the final learning rate should not be set to 0: with fixed Weight Decay, the limit \(\eta_t/(2\lambda_t)\) would then be 0, so prolonged training risks collapsing the weights toward zero.
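Equation (24) is easy to integrate numerically. The sketch below assumes forward Euler with unit step, which coincides with the per-step recurrence \(\rho \leftarrow (1-2\lambda\eta)\rho + \eta^2\), and uses illustrative schedules; the helper `integrate_rho` is hypothetical. A constant learning rate settles at \(\eta/(2\lambda)\), whereas a linear decay to exactly zero ends well below that level, and lower still when the same decay is stretched over a longer run.

```python
import numpy as np


def integrate_rho(eta, lam, rho0):
    """Forward-Euler (unit step) integration of Eq. (24): rho' = -2*lam*eta*rho + eta^2."""
    rho = rho0
    out = np.empty(len(eta))
    for k, (e, wd) in enumerate(zip(np.asarray(eta).tolist(), np.asarray(lam).tolist())):
        rho = rho + (-2 * wd * e * rho + e**2)
        out[k] = rho
    return out


T = 200000
lam = np.full(T, 0.1)
# Constant learning rate: rho should settle at eta / (2 * lam) = 5e-3.
rho_const = integrate_rho(np.full(T, 1e-3), lam, rho0=0.01)
# Illustrative linear decay to exactly zero, over T // 4 steps and over T steps.
rho_short = integrate_rho(np.linspace(1e-3, 0.0, T // 4), lam[: T // 4], rho0=0.01)
rho_long = integrate_rho(np.linspace(1e-3, 0.0, T), lam, rho0=0.01)
```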
Mean Field
The limit \(t\to\infty\) typically applies to multi-epoch supervised training. Pre-training usually runs for a single epoch, and \(\kappa_t\) is typically \(\Theta(1)\).
Under the assumption \(\kappa_t=\Theta(1)\), we can apply a mean-field approximation. Starting from the integral form (23), note that by definition \(\kappa_s\) increases monotonically from 0 to \(\kappa_t\) over \([0,t]\). Thus:
(26)
\[
e^{- 2\kappa_t}\int_0^t \eta_s^2 ds \leq e^{- 2\kappa_t}\int_0^t e^{2\kappa_s}\eta_s^2 ds = \int_0^t e^{2\kappa_s- 2\kappa_t}\eta_s^2 ds \leq \int_0^t \eta_s^2 ds
\]
That is, the target integral is bounded between \(e^{- 2\kappa_t} \nu_t\) and \(\nu_t\), where \(\nu_t = \int_0^t \eta_s^2 ds\). When \(\kappa_t=\Theta(1)\), the factor \(e^{-2\kappa_t}\) is not vanishingly small, so the two bounds differ only by a constant factor and \(\nu_t\) alone already gives the right order of magnitude.
We can be more precise by estimating this factor: replacing \(\eta_s^2\) with its mean value \(\nu_t/t\) and \(\kappa_s\) with the linear interpolation \((\kappa_t/t)s\) gives:
(27)
\[
e^{- 2\kappa_t}\int_0^t e^{2\kappa_s}\eta_s^2 ds \approx \frac{\nu_t e^{- 2\kappa_t}}{t}\int_0^t e^{2\kappa_s} ds \approx \frac{\nu_t e^{- 2\kappa_t}}{t}\int_0^t e^{2(\kappa_t/t)s} ds = \frac{\nu_t}{2\kappa_t}(1 - e^{- 2\kappa_t})
\]
Substituting into Equation (23):
(28)
\(\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx e^{-2\kappa_t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1 - e^{- 2\kappa_t})\frac{\nu_t}{2\kappa_t}\)
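The sketch below assumes the cosine schedule and hyperparameters from the simulation section and replaces the integrals with sums over steps. It checks the bounds in (26) and compares the series form (23) with the mean-field estimate (28).

```python
import numpy as np

T, lam, init_rms = 10000, 0.1, 0.1
lr_max, lr_min = 1e-3, 1e-4
eta = lr_min + (lr_max - lr_min) * (1 + np.cos(np.arange(T) / T * np.pi)) / 2

kappa_s = lam * np.cumsum(eta)             # kappa at each step
kappa_t, nu_t = kappa_s[-1], (eta**2).sum()

# Second term of (23) as a sum over steps, and its bounds from (26).
target = np.exp(-2 * kappa_t) * (np.exp(2 * kappa_s) * eta**2).sum()
lower, upper = np.exp(-2 * kappa_t) * nu_t, nu_t

# Mean-field estimate of the same term, as in (27)-(28).
mean_field = (1 - np.exp(-2 * kappa_t)) * nu_t / (2 * kappa_t)

rms_series = np.sqrt(np.exp(-2 * kappa_t) * init_rms**2 + target)
rms_mean_field = np.sqrt(np.exp(-2 * kappa_t) * init_rms**2 + mean_field)
```

Here \(\kappa_t \approx 0.55\), squarely in the \(\Theta(1)\) regime where the mean-field estimate is intended to apply.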
Example Three
Returning to our main concern, the common setting of fixed Weight Decay with a varying learning rate, we compute \(\kappa_t,\nu_t\) for several specific schedules.
First, a linear learning rate schedule:
(29)
\(\eta_s = \eta_a + (\eta_b - \eta_a) s / t\)
Integrating:
(30)
\[
\begin{aligned}
\kappa_t &= \int_0^t \lambda\eta_s ds= \lambda (\eta_a + \eta_b) t / 2 \\
\nu_t &= \int_0^t \eta_s^2 ds = (\eta_a^2 + \eta_a \eta_b + \eta_b^2) t / 3
\end{aligned}
\]
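The closed forms in (30) can be confirmed by numerical quadrature. The values of `eta_a`, `eta_b`, `lam`, and `t` below are illustrative, and `trapezoid` is a small hand-written helper (composite trapezoidal rule) rather than a library call.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule, written out to avoid NumPy version differences.
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2)


eta_a, eta_b, lam, t = 1e-3, 1e-4, 0.1, 10000.0   # illustrative values
s = np.linspace(0.0, t, 100001)
eta = eta_a + (eta_b - eta_a) * s / t              # Eq. (29)

kappa_closed = lam * (eta_a + eta_b) * t / 2
nu_closed = (eta_a**2 + eta_a * eta_b + eta_b**2) * t / 3
kappa_num = trapezoid(lam * eta, s)
nu_num = trapezoid(eta**2, s)
```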
Next, Cosine Decay:
(31)
\(\eta_s = \eta_{\min} + (\eta_{\max} - \eta_{\min})\left(\frac{1}{2} + \frac{1}{2}\cos \frac{s\pi}{t}\right)\)
Integrating:
(32)
\[
\begin{aligned}
\kappa_t &= \int_0^t \lambda\eta_s ds= \lambda (\eta_{\min} + \eta_{\max}) t / 2 \\
\nu_t &= \int_0^t \eta_s^2 ds = (3\eta_{\min}^2 + 2\eta_{\min} \eta_{\max} + 3\eta_{\max}^2 ) t / 8
\end{aligned}
\]
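The same kind of quadrature check applies to the Cosine Decay results in (32); again the constants are illustrative and `trapezoid` is a hand-written helper.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule, written out to avoid NumPy version differences.
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2)


eta_min, eta_max, lam, t = 1e-4, 1e-3, 0.1, 10000.0   # illustrative values
s = np.linspace(0.0, t, 100001)
eta = eta_min + (eta_max - eta_min) * (0.5 + 0.5 * np.cos(s * np.pi / t))   # Eq. (31)

kappa_closed = lam * (eta_min + eta_max) * t / 2
nu_closed = (3 * eta_min**2 + 2 * eta_min * eta_max + 3 * eta_max**2) * t / 8
kappa_num = trapezoid(lam * eta, s)
nu_num = trapezoid(eta**2, s)
```

The factor \(3/8\) comes from averaging \(\left(\frac{1}{2} + \frac{1}{2}\cos\frac{s\pi}{t}\right)^2\) over \([0,t]\).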
Finally, WSD (Warmup Stable Decay):
(33)
\[
\eta_s = \left\{\begin{aligned}
&\frac{s}{t_1}\eta_{\max}, \quad s \in [0, t_1] \\
&\eta_{\max}, \quad s \in [t_1, t_2] \\
&\frac{t-s}{t-t_2}\eta_{\max}, \quad s \in [t_2, t]
\end{aligned}\right.
\]
We have:
(34)
\[
\begin{aligned}
\kappa_t &= \int_0^t \lambda\eta_s ds= \lambda \eta_{\max} (t + t_2 - t_1) / 2 \\
\nu_t &= \int_0^t \eta_s^2 ds = \eta_{\max}^2 (t + 2t_2 - 2t_1) / 3
\end{aligned}
\]
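For WSD, the piecewise-linear schedule (33) can be built with `np.interp` and integrated the same way. The breakpoints `t1`, `t2`, `t` and the other constants are illustrative, and `trapezoid` is a hand-written helper.

```python
import numpy as np


def trapezoid(y, x):
    # Composite trapezoidal rule, written out to avoid NumPy version differences.
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2)


eta_max, lam = 1e-3, 0.1
t1, t2, t = 500.0, 8000.0, 10000.0            # illustrative warmup end and decay start
s = np.linspace(0.0, t, 100001)
eta = np.interp(s, [0.0, t1, t2, t], [0.0, eta_max, eta_max, 0.0])   # Eq. (33)

kappa_closed = lam * eta_max * (t + t2 - t1) / 2
nu_closed = eta_max**2 * (t + 2 * t2 - 2 * t1) / 3
kappa_num = trapezoid(lam * eta, s)
nu_num = trapezoid(eta**2, s)
```

Each linear ramp contributes half of its length to \(\kappa_t\) and one third of its length to \(\nu_t\) (in units of \(\lambda\eta_{\max}\) and \(\eta_{\max}^2\)), while the stable phase contributes its full length to both, which is where (34) comes from.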
Simulation
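We can also validate the above approximations through numerical simulation. The script draws zero-mean Gaussian gradients (\(\boldsymbol{\mu}=\boldsymbol{0}\)) and omits bias correction and \(\epsilon\), consistent with the approximations used in the derivation: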
```python
import numpy as np

N, T = 10000, 10000
beta1, beta2 = 0.9, 0.95
m, v = 0, 0
w = np.random.randn(N) * (init_std := 0.1)
lr_max, lr_min, wd = 0.001, 0.0001, 0.1
# Cosine Decay from lr_max to lr_min, as in Eq. (31)
lr = lr_min + (lr_max - lr_min) * (1 + np.cos(np.arange(T) / T * np.pi)) / 2
for i in range(T):
    g = np.random.randn(N)  # zero-mean, unit-variance gradients
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    # AdamW update without bias correction or epsilon
    w = w - lr[i] * (m / v**0.5 + wd * w)

# Direct simulation ≈ 0.0744
weight_rms = (w**2).mean()**0.5
# Series approximation, Eq. (21) ≈ 0.0742
kappa = wd * lr.cumsum()
approx1 = ((np.exp(kappa * 2) * lr**2).sum() + init_std**2)**0.5 * np.exp(-kappa[-1])
# Mean-field approximation, Eq. (28) ≈ 0.0760
kappa = wd * (lr_max + lr_min) / 2 * T
nu = (3 * lr_max**2 + 2 * lr_max * lr_min + 3 * lr_min**2) / 8 * T
approx2 = ((np.exp(kappa * 2) - 1) * nu / kappa / 2 + init_std**2)**0.5 * np.exp(-kappa)
print(weight_rms)
print(approx1)
print(approx2)
```
Summary
This article extends the results from Part 1 to a dynamic version, allowing us to estimate Weight RMS for AdamW with time-varying learning rates and Weight Decay.
Original Article: Su Jianlin. Asymptotic Estimation of Weight RMS in AdamW (Part 2). Scientific Spaces.
How to cite this translation:
Su, J. Asymptotic Estimation of Weight RMS in AdamW (Part 2) [Translated by Juanxi Tian]. Scientific Spaces.
BibTeX:
@article{su2025adamw_weight_rms_2,
title = {Asymptotic Estimation of Weight RMS in AdamW (Part 2)},
author = {Su, Jianlin},
journal = {Scientific Spaces},
year = {2025},
url = {https://kexue.fm/archives/11404},
note = {Translated by Juanxi Tian (ScalingOpt Team)}
}